/*
* LearnThread.java 
* 
* Copyright (C) 2009 Aalborg University
*
* contact:
* jaeger@cs.aau.dk   http://www.cs.aau.dk/~jaeger/Primula.html
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License
* as published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
*/

package RBNLearning;

import RBNExceptions.RBNCompatibilityException;
import RBNpackage.*;
import RBNgui.*;
import RBNutilities.*;
import RBNinference.*;
import java.util.*;

import javax.swing.*;

import myio.StringOps;

import java.text.*;

/**
 * Worker thread that learns the parameters of the current RBN model from
 * relational data.
 *
 * The thread collects all learnable parameters (the RBN's own parameters plus
 * one parameter per true ground atom of every parameter numeric relation),
 * publishes them in the parameter table, and then performs restarts until the
 * thread is stopped or the number of restarts configured in the LearnModule
 * is reached. After every restart that improves the likelihood, the parameter
 * table is updated with the new best estimates.
 *
 * NOTE(review): Swing components (table, text field) are updated directly
 * from this worker thread, as in the original code. Swing is generally not
 * thread-safe; consider dispatching these updates to the event dispatch thread.
 */
public class LearnThread extends GGThread {

	/* Number of trailing non-parameter entries in the arrays returned by
	 * GradientGraph.learnParameters, as assumed by the block-wise learning
	 * code below. The last of these entries is used as the likelihood value.
	 */
	private static final int NUM_RESULT_EXTRAS = 4;

	Primula myprimula;
	LearnModule myLearnModule;
	RelData data;
	ParameterTableModel parammodel;
	JTable parametertable;
	JTextField numrestartsfield;
	DecimalFormat timeformat = new DecimalFormat("#.##");


	/**
	 * @param mypr    the Primula instance providing the RBN and message display
	 * @param d       the data to learn from; if null, run() does nothing
	 * @param parmod  table model receiving parameter names and estimates
	 * @param partab  table displaying parmod
	 * @param nrest   text field showing the number of completed restarts
	 * @param mylm    learn module providing the learning settings
	 */
	public LearnThread(Primula mypr,
			RelData d,
			ParameterTableModel parmod,
			JTable partab,
			JTextField nrest,
			LearnModule mylm){
		myprimula = mypr;
		myLearnModule = mylm;
		data = d;
		parammodel = parmod;
		parametertable = partab;
		numrestartsfield = nrest;
	}

	/**
	 * Runs the learning procedure: determines the parameters, optionally
	 * builds a single gradient graph for all of them, and performs restarts
	 * while keeping the estimates with the highest likelihood.
	 */
	public void run()
	{
		if (data == null)
			return;

		/* Determines the parameters contained in the rbn model. */
		String[] rbnparameters = myprimula.getRBN().parameters();
		String[] parameternumrels = myprimula.getParamNumRels();

		/* The input domain of the first data case is used to ground the numrel parameters */
		RelDataForOneInput rdoi = data.caseAt(0);
		RelStruc A = rdoi.inputDomain();

		String[] parameters = collectParameters(A, rbnparameters, parameternumrels);

		parammodel.setParameters(parameters);
		parammodel.fireTableDataChanged();
		parametertable.updateUI();

		numrestartsfield.setText("" );

		/* One entry per parameter plus a final entry for the likelihood */
		double[] paramvals = new double[parameters.length+1];

		/* NOTE(review): assumes getNumblocks() >= 1; other values make the
		 * partitioning below fail -- confirm against LearnModule.
		 */
		int numblocks = myLearnModule.getNumblocks();
		int blocklength = parameters.length/numblocks;
		String[][] paramblocks = partitionIntoBlocks(parameters, numblocks, blocklength);

		double timestart = System.currentTimeMillis();
		GradientGraph gg = null;

		if (numblocks==1){ /* Can build one GG once and for all */
			gg = buildGG(parameters,true);
			if (gg == null)
				return; /* construction failed; buildGG has reported the reason */
		}

		/* Best likelihood value obtained in any restart so far */
		double currentbestlik = Double.NEGATIVE_INFINITY;

		/* rest is the number of completed restarts */
		int rest = 0;

		while (!isstopped() && (rest < myLearnModule.getRestarts()
				|| myLearnModule.getRestarts() == -1)){

			System.out.println("***** RESTART **********");
			double[] results = doOneRestart(gg,A,parameternumrels,parameters,blocklength,paramblocks,rest==0);
			if (results == null)
				break; /* restart produced no result (stopped early or GG construction failed) */

			double newlik = results[results.length-1];
			System.out.println("Likelihood: " + newlik);

			if (newlik > currentbestlik){
				currentbestlik = newlik;
				System.arraycopy(results, 0, paramvals, 0, parameters.length);
				paramvals[paramvals.length-1] = newlik;
				parammodel.setEstimates(paramvals);
			}
			rest++;
			numrestartsfield.setText(""+rest);
			parametertable.updateUI();
		}

		double timediff = System.currentTimeMillis()-timestart;
		/* rest already counts the completed restarts; guard only against division by zero */
		double timeperrs = timediff/(1000.0*Math.max(rest,1));
		myprimula.showMessageThis("done. Time per restart: " + timeformat.format(timeperrs) +"s");
	}

	/**
	 * Collects the full, sorted list of parameters: the RBN parameters
	 * together with one parameter for every tuple returned by A.allTrue for
	 * each parameter numeric relation.
	 *
	 * @param A                 input domain used to ground the numeric relations
	 * @param rbnparameters     parameters contained in the RBN
	 * @param parameternumrels  names of the parameter numeric relations
	 * @return the sorted array of all parameter names
	 */
	private String[] collectParameters(RelStruc A,
			String[] rbnparameters,
			String[] parameternumrels){
		/* Construct the parameters corresponding to ground numrel atoms */
		String[] nrparams = new String[0];
		for (String numrel : parameternumrels){
			Vector<String[]> alltuples = A.allTrue(numrel,A);
			String[] nextparams = new String[alltuples.size()];
			for (int k=0;k<nextparams.length;k++)
				nextparams[k]=numrel+StringOps.arrayToString(alltuples.elementAt(k),"(",")");
			nrparams = rbnutilities.arraymerge(nrparams,nextparams);
		}
		String[] parameters = rbnutilities.arrayConcatenate(rbnparameters, nrparams);
		Arrays.sort(parameters);
		return parameters;
	}

	/**
	 * Splits parameters into numblocks consecutive blocks. The first
	 * numblocks-1 blocks have length blocklength; the last block takes all
	 * remaining parameters.
	 */
	private static String[][] partitionIntoBlocks(String[] parameters,
			int numblocks,
			int blocklength){
		String[][] paramblocks = new String[numblocks][];
		for (int i=0;i<numblocks;i++){
			int from = i*blocklength;
			int to = (i == numblocks-1) ? parameters.length : from+blocklength;
			paramblocks[i] = Arrays.copyOfRange(parameters, from, to);
		}
		return paramblocks;
	}

	/**
	 * Performs a single restart from random parameter values.
	 *
	 * If gg is not null, it is used to learn all parameters at once.
	 * Otherwise the parameters are learned block by block, building a new
	 * gradient graph for every block, and the sweep over all blocks is
	 * repeated until the likelihood ratio test below signals termination or
	 * the thread is stopped.
	 *
	 * @param gg                gradient graph for all parameters, or null for block-wise learning
	 * @param A                 input domain whose numeric relations may be randomized
	 * @param parameternumrels  names of the parameter numeric relations
	 * @param parameters        all parameter names
	 * @param blocksize         length of every block except possibly the last
	 * @param paramblocks       the parameters partitioned into blocks
	 * @param isfirstrestart    true for the first restart; enables progress messages
	 * @return in the gg case whatever gg.learnParameters returns; in the
	 *         block-wise case the learned parameter values followed by the
	 *         trailing entries of the last block result; null if block-wise
	 *         learning produced no result (stopped before any block was
	 *         learned, or a gradient graph could not be built)
	 */
	private double[] doOneRestart(GradientGraph gg,
			RelStruc A,
			String[] parameternumrels,
			String[] parameters,
			int blocksize,
			String[][] paramblocks,
			boolean isfirstrestart){
		myprimula.getRBN().setRandomParameterVals();
		if (myLearnModule.ggrandominit())
			A.setRandom(parameternumrels);

		if (gg != null){
			gg.setParametersFromAandRBN();
			return gg.learnParameters(this);
		}

		double[] paramvals = new double[parameters.length+NUM_RESULT_EXTRAS];
		double[] blockresult = null;
		double lastlik = Double.NEGATIVE_INFINITY;
		boolean terminate = false;
		boolean isfirstloop = true;

		while (!terminate && !isstopped()){
			for (int i=0;i<paramblocks.length;i++){
				GradientGraph blockgg = buildGG(paramblocks[i],isfirstrestart && isfirstloop);
				if (blockgg == null)
					return null; /* buildGG has reported the failure */
				blockresult = blockgg.learnParameters(this);
				myprimula.setParameters(paramblocks[i],
						Arrays.copyOfRange(blockresult,0,blockresult.length-NUM_RESULT_EXTRAS));

				/* Insert learned block of values in full array*/
				System.arraycopy(blockresult, 0, paramvals, i*blocksize, paramblocks[i].length);
				System.out.println("blocklik " + i + " : " + blockresult[blockresult.length-1]);
			}
			double currlik = blockresult[blockresult.length-1];
			System.out.println("currlik: " + currlik);
			/* Stop once a full sweep changes the likelihood value by a factor
			 * of less than 1.01. NOTE(review): presumably these are negative
			 * log-likelihoods, so the ratio is >= 1 while improving -- confirm.
			 */
			if (lastlik/currlik < 1.01)
				terminate = true;
			lastlik = currlik;
			isfirstloop = false;
		}
		System.out.println();

		if (blockresult == null)
			return null; /* stopped before any block was learned */

		/* Append the trailing entries of the last block result */
		System.arraycopy(blockresult, blockresult.length-NUM_RESULT_EXTRAS,
				paramvals, parameters.length, NUM_RESULT_EXTRAS);
		return paramvals;
	}

	/**
	 * Builds a gradient graph in learn mode for the given parameters.
	 *
	 * @param parameters         the parameters the graph is built for
	 * @param showInfoInPrimula  if true, progress and construction time are
	 *                           shown in Primula
	 * @return the gradient graph, or null if construction failed with an
	 *         RBNCompatibilityException (the failure is printed and shown
	 *         in Primula)
	 */
	private GradientGraph buildGG(String[] parameters,boolean showInfoInPrimula){
		if (showInfoInPrimula)
			myprimula.showMessageThis("Building Gradient Graph ...");
		double timestart=System.currentTimeMillis();
		GradientGraph gg;
		try{
			gg = new GradientGraph(myprimula,
					data,
					parameters,
					myLearnModule,
					null,
					GradientGraph.LEARNMODE,
					showInfoInPrimula);
		}
		catch (RBNCompatibilityException ex){
			/* Report the failure instead of silently handing back a null graph */
			System.out.println(ex);
			myprimula.showMessageThis("Could not build Gradient Graph: " + ex);
			return null;
		}
		double timediff = (System.currentTimeMillis()-timestart)/1000;

		if (showInfoInPrimula)
			myprimula.showMessageThis("Construction time: " + timeformat.format(timediff) +"s");
		return gg;
	}


}
